[mlir] [Vector] Add IndexBitWidth option to vector-to-llvm pass #128154
Conversation
cc: @javedabsar1
@llvm/pr-subscribers-mlir

Author: None (quic-rb10)

Changes

The VectorToLLVM pass currently provides an option (force32BitVectorIndices) to force 32-bit vector indices, but it has no mechanism to override the index bit width generically. To address this, this patch introduces a new indexBitWidth option for the VectorToLLVM pass, allowing users to specify the bit width of the index type.

Patch is 62.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128154.diff

6 Files Affected:
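As a usage sketch (the input function is hypothetical; the invocation matches the test added below), forcing a 32-bit index changes the element-index constants emitted when lowering vector element accesses:

// RUN: mlir-opt %s -convert-vector-to-llvm='index-bitwidth=32'
func.func @extract_lane(%v: vector<4xf32>) -> f32 {
  %0 = vector.extract %v[2] : f32 from vector<4xf32>
  return %0 : f32
}
// With index-bitwidth=32 the lane index lowers roughly to:
//   %c2 = llvm.mlir.constant(2 : i32) : i32
//   %r  = llvm.extractelement %v[%c2 : i32] : vector<4xf32>
// With the default of 0, the width is derived from the data layout,
// typically yielding i64 constants.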
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..20eb6392daf49 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1414,6 +1414,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"vector::VectorTransformsOptions",
/*default=*/"vector::VectorTransformsOptions()",
"Options to lower some operations like contractions and transposes.">,
+ Option<"indexBitwidth", "index-bitwidth", "unsigned",
+ /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
+ "Bitwidth of the index type, 0 to use size of machine word">,
];
}
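For reference, a sketch of how the new option can be spelled on the command line and in a textual pass pipeline (the input file name is a placeholder):

// mlir-opt input.mlir -convert-vector-to-llvm='index-bitwidth=32'
// mlir-opt input.mlir --pass-pipeline='builtin.module(convert-vector-to-llvm{index-bitwidth=32})'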
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..1f8a222282aac 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -49,10 +49,9 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
int64_t pos) {
assert(rank > 0 && "0-D vector corner case should have been handled already");
if (rank == 1) {
- auto idxType = rewriter.getIndexType();
+ auto idxType = typeConverter.convertType(rewriter.getIndexType());
auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(idxType),
- rewriter.getIntegerAttr(idxType, pos));
+ loc, idxType, rewriter.getIntegerAttr(idxType, pos));
return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
constant);
}
@@ -64,10 +63,9 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, Location loc,
Value val, Type llvmType, int64_t rank, int64_t pos) {
if (rank <= 1) {
- auto idxType = rewriter.getIndexType();
+ auto idxType = typeConverter.convertType(rewriter.getIndexType());
auto constant = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter.convertType(idxType),
- rewriter.getIntegerAttr(idxType, pos));
+ loc, idxType, rewriter.getIntegerAttr(idxType, pos));
return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
constant);
}
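With the two helper changes above, the constant is built directly in the converted index type rather than the builtin index type. A minimal sketch of the resulting rank-1 lowering under index-bitwidth=32 (function name hypothetical):

func.func @insert_lane(%s: f32, %v: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.insert %s, %v[1] : f32 into vector<4xf32>
  return %0 : vector<4xf32>
}
// lowers roughly to:
//   %c1 = llvm.mlir.constant(1 : i32) : i32
//   %r  = llvm.insertelement %s, %v[%c1 : i32] : vector<4xf32>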
@@ -1064,10 +1062,9 @@ class VectorExtractElementOpConversion
if (vectorType.getRank() == 0) {
Location loc = extractEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
+ auto idxType = typeConverter->convertType(rewriter.getIndexType());
auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
+ loc, idxType, rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
extractEltOp, llvmType, adaptor.getVector(), zero);
return success();
@@ -1198,10 +1195,9 @@ class VectorInsertElementOpConversion
if (vectorType.getRank() == 0) {
Location loc = insertEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
+ auto idxType = typeConverter->convertType(rewriter.getIndexType());
auto zero = rewriter.create<LLVM::ConstantOp>(
- loc, typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
+ loc, idxType, rewriter.getIntegerAttr(idxType, 0));
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
return success();
@@ -1439,8 +1435,6 @@ class VectorTypeCastOpConversion
if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
return failure();
- auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
-
// Create descriptor.
auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
// Set allocated ptr.
@@ -1451,21 +1445,24 @@ class VectorTypeCastOpConversion
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill offset 0.
- auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
- auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
+
+ auto idxType = typeConverter->convertType(rewriter.getIndexType());
+ auto zero = rewriter.create<LLVM::ConstantOp>(
+ loc, idxType, rewriter.getIntegerAttr(idxType, 0));
desc.setOffset(rewriter, loc, zero);
// Fill size and stride descriptors in memref.
for (const auto &indexedSize :
llvm::enumerate(targetMemRefType.getShape())) {
int64_t index = indexedSize.index();
- auto sizeAttr =
- rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
- auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
+
+ auto size = rewriter.create<LLVM::ConstantOp>(
+ loc, idxType, rewriter.getIntegerAttr(idxType, indexedSize.value()));
desc.setSize(rewriter, loc, index, size);
- auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
- (*targetStrides)[index]);
- auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
+
+ auto stride = rewriter.create<LLVM::ConstantOp>(
+ loc, idxType,
+ rewriter.getIntegerAttr(idxType, (*targetStrides)[index]));
desc.setStride(rewriter, loc, index, stride);
}
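With the hardcoded int64Ty gone, the descriptor's offset, size, and stride constants all take the converted index type; the @type_cast_f32 test below shows the full i32 descriptor. A minimal sketch of the constants emitted for memref<8x8x8xf32> under index-bitwidth=32:

//   %offset  = llvm.mlir.constant(0 : i32) : i32   // offset is always 0 here
//   %size0   = llvm.mlir.constant(8 : i32) : i32   // size of dimension 0
//   %stride0 = llvm.mlir.constant(64 : i32) : i32  // stride of dimension 0 (8*8)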
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a81bd20212d..c9b6a528c03b6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -86,6 +86,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
LowerToLLVMOptions options(&getContext());
+ if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
+ options.overrideIndexBitwidth(indexBitwidth);
LLVMTypeConverter converter(&getContext(), options);
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
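As wired above, the pass reuses the existing kDeriveIndexBitwidthFromDataLayout sentinel (0): leaving the option at its default installs no override, so the LLVMTypeConverter keeps deriving the index width from the data layout, while any nonzero value takes effect for the whole conversion. A hedged sketch of the two behaviors:

// mlir-opt %s -convert-vector-to-llvm
//   (index-bitwidth=0) index lowers per the data layout, typically i64
// mlir-opt %s -convert-vector-to-llvm='index-bitwidth=32'
//   index lowers to i32, as exercised by the new test file below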
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-index-bitwidth.mlir b/mlir/test/Conversion/VectorToLLVM/vector-index-bitwidth.mlir
new file mode 100644
index 0000000000000..0869cd28b29b2
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-index-bitwidth.mlir
@@ -0,0 +1,674 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm='index-bitwidth=32' -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @masked_reduce_add_f32_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[16]xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<[16]xi1>) -> f32 {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_8:.*]] = "llvm.intr.vp.reduce.fadd"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
+// CHECK: return %[[VAL_8]] : f32
+// CHECK: }
+func.func @masked_reduce_add_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @masked_reduce_minf_f32_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[16]xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<[16]xi1>) -> f32 {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_8:.*]] = "llvm.intr.vp.reduce.fmin"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32
+// CHECK: return %[[VAL_8]] : f32
+// CHECK: }
+func.func @masked_reduce_minf_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
+ %0 = vector.mask %mask { vector.reduction <minnumf>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @masked_reduce_add_i8_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[32]xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_8:.*]] = "llvm.intr.vp.reduce.add"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+// CHECK: return %[[VAL_8]] : i8
+// CHECK: }
+func.func @masked_reduce_add_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+ return %0 : i8
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @masked_reduce_minui_i8_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[32]xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_8:.*]] = "llvm.intr.vp.reduce.umin"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+// CHECK: return %[[VAL_8]] : i8
+// CHECK: }
+func.func @masked_reduce_minui_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+ return %0 : i8
+}
+
+// -----
+
+// CHECK-LABEL: func.func @masked_reduce_maxsi_i8_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[32]xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(-128 : i8) : i8
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_8:.*]] = "llvm.intr.vp.reduce.smax"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+// CHECK: return %[[VAL_8]] : i8
+// CHECK: }
+func.func @masked_reduce_maxsi_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+ return %0 : i8
+}
+
+// -----
+
+// CHECK-LABEL: func.func @masked_reduce_xor_i8_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<[32]xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<[32]xi1>) -> i8 {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = "llvm.intr.vscale"() : () -> i32
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : i32 to index
+// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32
+// CHECK: %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_8:.*]] = "llvm.intr.vp.reduce.xor"(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]], %[[VAL_7]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8
+// CHECK: return %[[VAL_8]] : i8
+// CHECK: }
+func.func @masked_reduce_xor_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
+ return %0 : i8
+}
+
+// -----
+
+// CHECK-LABEL: func.func @shuffle_1D(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>) -> vector<5xf32> {
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.poison : vector<5xf32>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_3]] : i32] : vector<3xf32>
+// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_6:.*]] = llvm.insertelement %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_5]] : i32] : vector<5xf32>
+// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_8:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_7]] : i32] : vector<3xf32>
+// CHECK: %[[VAL_9:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_10:.*]] = llvm.insertelement %[[VAL_8]], %[[VAL_6]]{{\[}}%[[VAL_9]] : i32] : vector<5xf32>
+// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_12:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_11]] : i32] : vector<3xf32>
+// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[VAL_14:.*]] = llvm.insertelement %[[VAL_12]], %[[VAL_10]]{{\[}}%[[VAL_13]] : i32] : vector<5xf32>
+// CHECK: %[[VAL_15:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_16:.*]] = llvm.extractelement %[[VAL_0]]{{\[}}%[[VAL_15]] : i32] : vector<2xf32>
+// CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: %[[VAL_18:.*]] = llvm.insertelement %[[VAL_16]], %[[VAL_14]]{{\[}}%[[VAL_17]] : i32] : vector<5xf32>
+// CHECK: %[[VAL_19:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_20:.*]] = llvm.extractelement %[[VAL_0]]{{\[}}%[[VAL_19]] : i32] : vector<2xf32>
+// CHECK: %[[VAL_21:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[VAL_22:.*]] = llvm.insertelement %[[VAL_20]], %[[VAL_18]]{{\[}}%[[VAL_21]] : i32] : vector<5xf32>
+// CHECK: return %[[VAL_22]] : vector<5xf32>
+// CHECK: }
+func.func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
+ %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32>
+ return %1 : vector<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extractelement_from_vec_0d_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<f32>) -> f32 {
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<f32> to vector<1xf32>
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_3:.*]] = llvm.extractelement %[[VAL_1]]{{\[}}%[[VAL_2]] : i32] : vector<1xf32>
+// CHECK: return %[[VAL_3]] : f32
+// CHECK: }
+func.func @extractelement_from_vec_0d_f32(%arg0: vector<f32>) -> f32 {
+ %1 = vector.extractelement %arg0[] : vector<f32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @insertelement_into_vec_0d_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: f32,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<f32>) -> vector<f32> {
+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<f32> to vector<1xf32>
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = llvm.insertelement %[[VAL_0]], %[[VAL_2]]{{\[}}%[[VAL_3]] : i32] : vector<1xf32>
+// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : vector<1xf32> to vector<f32>
+// CHECK: return %[[VAL_5]] : vector<f32>
+// CHECK: }
+func.func @insertelement_into_vec_0d_f32(%arg0: f32, %arg1: vector<f32>) -> vector<f32> {
+ %1 = vector.insertelement %arg0, %arg1[] : vector<f32>
+ return %1 : vector<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @type_cast_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : memref<8x8x8xf32> to !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i32)>
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][0] : !llvm.struct<(ptr, ptr, i32)>
+// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.struct<(ptr, ptr, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_4]][1] : !llvm.struct<(ptr, ptr, i32)>
+// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_6]][2] : !llvm.struct<(ptr, ptr, i32)>
+// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]] : !llvm.struct<(ptr, ptr, i32)> to memref<vector<8x8x8xf32>>
+// CHECK: return %[[VAL_9]] : memref<vector<8x8x8xf32>>
+// CHECK: }
+func.func @type_cast_f32(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
+ %0 = vector.type_cast %arg0: memref<8x8x8xf32> to memref<vector<8x8x8xf32>>
+ return %0 : memref<vector<8x8x8xf32>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @type_cast_non_zero_addrspace(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<8x8x8xf32, 3>) -> memref<vector<8x8x8xf32>, 3> {
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : memref<8x8x8xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<3>, ptr<3>, i32)>
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][0] : !llvm.struct<(ptr<3>, ptr<3>, i32)>
+// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<3 x i32>, array<3 x i32>)>
+// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL...
[truncated]
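For contrast with the 32-bit output above, a hedged sketch of the default lowering (no index-bitwidth option, assuming a 64-bit data layout) for the same shuffle_1D element accesses, in the style of the existing vector-to-llvm.mlir tests:

//   %c2 = llvm.mlir.constant(2 : i64) : i64
//   %e  = llvm.extractelement %arg1[%c2 : i64] : vector<3xf32>
//   %c0 = llvm.mlir.constant(0 : i64) : i64
//   %r  = llvm.insertelement %e, %poison[%c0 : i64] : vector<5xf32>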
The review discussion below anchors on the new option declaration in Passes.td:

      /*default=*/"vector::VectorTransformsOptions()",
      "Options to lower some operations like contractions and transposes.">,
    Option<"indexBitwidth", "index-bitwidth", "unsigned",
           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
Cool! How does deriving the index bitwidth from the data layout work? Is there a single bitwidth per module? It would be great if we could extend this to have a bitwidth per address space at some point :)
@dcaballe Could you please elaborate on this? I have added some of my findings regarding the index bitwidth derived from the data layout as a reply to another one of your comments.